from typing import List


class Solution:
    """Solution for the "Check if Matrix Is X-Matrix" problem."""

    def checkXMatrix(self, grid: List[List[int]]) -> bool:
        """Return True if *grid* is an X-Matrix.

        A square matrix is an X-Matrix when every element on its main
        diagonal and anti-diagonal is non-zero, and every other element
        is zero.

        Args:
            grid: Square matrix of integers (assumed n x n, as the problem
                guarantees). It is only read, never modified.

        Returns:
            True if *grid* satisfies the X-Matrix property, else False.
        """
        n = len(grid)
        for i, row in enumerate(grid):
            for j, value in enumerate(row):
                on_diagonal = i == j or i + j == n - 1
                # Diagonal cells must be non-zero, all others must be zero:
                # i.e. "is non-zero" must coincide with "is on a diagonal".
                if on_diagonal != (value != 0):
                    return False
        return True




if __name__ == '__main__':
    # Demo: this sample is an X-Matrix, so the script prints True.
    sample = [[5, 0, 20], [0, 5, 0], [6, 0, 2]]
    print(Solution().checkXMatrix(sample))

